import numpy as np
import xml.etree.ElementTree as ET
import os.path as osp
import os

def intersection_over_origin(target, origin):
    """Return the area of the target/origin overlap divided by the area of origin.

    Boxes use inclusive pixel coordinates, so a box spanning xmin..xmax covers
    ``xmax - xmin + 1`` pixels (hence the ``+ 1`` terms).

    Args:
        target: sequence of xmin, ymin, xmax, ymax;
        origin: sequence of xmin, ymin, xmax, ymax; the box whose area is the
            denominator.
    Returns:
        a float in [0, 1] for well-formed boxes: the fraction of ``origin``
        covered by ``target``.
    Raises:
        ZeroDivisionError: if ``origin`` has zero area.
    """
    # Corners of the overlapping rectangle.
    left = max(target[0], origin[0])
    top = max(target[1], origin[1])
    right = min(target[2], origin[2])
    bottom = min(target[3], origin[3])

    # Negative extents mean the boxes do not overlap on that axis.
    overlap_w = max(0, right - left + 1)
    overlap_h = max(0, bottom - top + 1)

    origin_w = origin[2] - origin[0] + 1
    origin_h = origin[3] - origin[1] + 1

    return overlap_w * overlap_h / float(origin_w * origin_h)

def voc_basic_xml_element_generator(filename, folder, size_w, size_h, size_c):
    '''
    description: build the common header elements of a VOC label
    param:
        filename: text of the <filename> element
        folder: text of the <folder> element
        size_w, size_h, size_c: image width, height and depth; stored as str()
    return:
        (filename_element, folder_element, size_element); the <size> element
        holds <width>, <height>, <depth> children in that order
    '''
    filename_ele = ET.Element('filename')
    filename_ele.text = filename

    folder_ele = ET.Element('folder')
    folder_ele.text = folder

    size_ele = ET.Element('size')
    for tag, value in (('width', size_w), ('height', size_h), ('depth', size_c)):
        ET.SubElement(size_ele, tag).text = str(value)

    return filename_ele, folder_ele, size_ele

def voc_object_xml_element_resolver(object_element):
    '''
    description: convert an <object> element of a VOC xml into a dict
    param:
        object_element: the <object> element to convert; it must contain a
            <bndbox> child (AttributeError otherwise)
    return:
        dict with keys 'name', 'pose', 'truncated', 'difficult' and 'bndbox'.
        'bndbox' maps both axis-aligned (xmin..ymax) and four-corner
        (x1..y4) coordinate tags to their text. Any tag that is absent maps
        to None.
    '''
    box = object_element.find('bndbox')
    field_tags = ('name', 'pose', 'truncated', 'difficult')
    coord_tags = ('xmin', 'ymin', 'xmax', 'ymax',
                  'x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4')

    result = {tag: object_element.findtext(tag) for tag in field_tags}
    result['bndbox'] = {tag: box.findtext(tag) for tag in coord_tags}
    return result


def voc_object_xml_element_generator(object_dict):
    '''
    description: convert the dict of a VOC object into an <object> XML element
    param:
        object_dict: the dict to convert; every key becomes a child element in
            insertion order. The 'bndbox' value must itself be a dict of
            coordinate name -> value.
    return:
        obj: converted element
    notes:
        - Non-None values are stored via str(), so numeric fields such as
          difficult=0 serialize correctly. Raw non-str text makes
          ElementTree serialization raise TypeError.
        - A top-level field whose value is None becomes an empty element.
        - A bndbox coordinate whose value is None is omitted. This happens
          for coordinates absent from the source file, e.g. x1..y4 as
          returned by voc_object_xml_element_resolver. Writing str(None)
          would produce the bogus text "None".
    '''
    obj = ET.Element('object')
    for key, value in object_dict.items():
        child = ET.SubElement(obj, key)
        if key == 'bndbox':
            for coord_name, coord in value.items():
                if coord is None:
                    continue
                ET.SubElement(child, coord_name).text = str(coord)
        elif value is not None:
            child.text = str(value)
    return obj
        
    

def voc_subimg_box_generator(input_w, input_h, subimg_w, subimg_h, overlap_w, overlap_h):
    '''
    description: generate the locations of the sub-images (tiles) of an image
    param:
        input_w, input_h: size of the input image
        subimg_w, subimg_h: size of each sub-image
        overlap_w, overlap_h: overlap between horizontally / vertically
            adjacent sub-images; must be smaller than subimg_w / subimg_h
    return:
        ((step_w, step_h), boxes):
            step_w: number of sub-images per row (0 when no row fits)
            step_h: number of rows
            boxes: row-major list of (xmin, ymin, xmax, ymax) tuples; only
                tiles that fit entirely inside the image are produced
    raises:
        ValueError: if an overlap is not smaller than the sub-image size. The
            stride would be <= 0 and the scan below would never terminate.
    '''
    offset_w = subimg_w - overlap_w
    offset_h = subimg_h - overlap_h
    if offset_w <= 0 or offset_h <= 0:
        raise ValueError(
            'overlap must be smaller than the sub-image size, got '
            'subimg=({}, {}) overlap=({}, {})'.format(subimg_w, subimg_h, overlap_w, overlap_h))

    # Split the original image into sub-images.
    # TODO: also crop the leftover strips at the right/bottom edges that do
    # not fill a whole sub-image (they are currently discarded).
    boxes = []
    step_w, step_h = (0, 0)
    while step_h * offset_h + subimg_h <= input_h:
        step_w = 0
        while step_w * offset_w + subimg_w <= input_w:
            # (xmin, ymin, xmax, ymax)
            subimg = (step_w * offset_w, step_h * offset_h, step_w * offset_w + subimg_w, step_h * offset_h + subimg_h)
            boxes.append(subimg)
            step_w += 1
        step_h += 1
    return ((step_w, step_h), boxes)

def voc_label_preprocess(input_anno):
    '''
    description: parse an annotation file and split the root's children into
        the annotated objects and everything else
    param:
        input_anno: path (or file object) accepted by ET.parse
    return:
        (anno_info, anno_obj): anno_info holds every direct child of the root
        whose tag is not 'object', in document order; anno_obj holds the
        <object> children, in document order
    '''
    # Read the annotation file.
    root = ET.parse(input_anno).getroot()
    children = list(root)
    # Annotated targets.
    anno_obj = [child for child in children if child.tag == 'object']
    # Everything else is kept unchanged.
    anno_info = [child for child in children if child.tag != 'object']
    return (anno_info, anno_obj)

def is_small_box_in_big(small_box, big_box):
    '''
    Return True when small_box lies entirely inside big_box (touching edges
    counts as inside).
    small_box: tuple(xmin, ymin, xmax, ymax)
    big_box: tuple(xmin, ymin, xmax, ymax)
    '''
    sticks_out = (small_box[0] < big_box[0]
                  or small_box[1] < big_box[1]
                  or small_box[2] > big_box[2]
                  or small_box[3] > big_box[3])
    return not sticks_out

def lefttop_rightdown_2_center_width(bbox):
    '''
    Convert corner form to center/size form.
    bbox: tuple(xmin, ymin, xmax, ymax); extra trailing items are ignored
    return: tuple(xmid, ymid, width, height), where width = xmax - xmin and
        height = ymax - ymin
    '''
    xmin, ymin, xmax, ymax = bbox[0], bbox[1], bbox[2], bbox[3]
    x_center = (xmin + xmax) / 2
    y_center = (ymin + ymax) / 2
    return (x_center, y_center, xmax - xmin, ymax - ymin)

def center_width_2_lefttop_rightdown(bbox):
    '''
    Convert center/size form to corner form (inverse of
    lefttop_rightdown_2_center_width).
    bbox: tuple(xmid, ymid, width, height)
    return: tuple(xmin, ymin, xmax, ymax)
    '''
    x_center, y_center = bbox[0], bbox[1]
    half_w = bbox[2] / 2
    # The vertical extent must come from the height (bbox[3]), not the width.
    half_h = bbox[3] / 2
    return (x_center - half_w, y_center - half_h, x_center + half_w, y_center + half_h)

def clip_bbox(bbox, im_shape):
    '''
    make sure every coordinate of the input bbox lies inside im_shape and
    return int numbers
    param: bbox: tuple(xmin, ymin, xmax, ymax)
    param: im_shape: tuple(im_width, im_height)
    return: the bbox clamped so that x is in [0, im_width - 1] and y is in
        [0, im_height - 1], each value truncated with int()
    note: a box lying completely outside the image collapses onto the border.
        It is not rejected.
    '''
    xmin, ymin, xmax, ymax = bbox
    max_x = im_shape[0] - 1
    max_y = im_shape[1] - 1

    def _clamp(value, upper):
        # Clamp both sides. Clipping only the "expected" side would let
        # xmin >= width or xmax < 0 pass through unchanged.
        return min(max(value, 0), upper)

    return (int(_clamp(xmin, max_x)), int(_clamp(ymin, max_y)),
            int(_clamp(xmax, max_x)), int(_clamp(ymax, max_y)))

def __load_voc_annotation(filename):
    '''
    load one VOC annotation file: its basic information (filename, folder,
    size) and the bounding boxes of its objects
    param: filename: path (or file object) accepted by ET.parse
    return: dict with keys 'filename', 'folder', 'size' (width/height/depth)
        and 'object' (a list of dicts with name/truncated/difficult/bndbox).
        All values are the raw element text (strings). Any required tag that
        is missing raises AttributeError.
    '''
    root = ET.parse(filename).getroot()

    def text_of(parent, tag):
        # Equivalent to parent.find(tag).text (not findtext): a missing tag
        # raises AttributeError, and an empty element yields None.
        return parent.find(tag).text

    def parse_object(obj):
        # Convert one <object> element into the dict layout used by this module.
        box = obj.find('bndbox')
        return {
            'name': text_of(obj, 'name'),
            'truncated': text_of(obj, 'truncated'),
            'difficult': text_of(obj, 'difficult'),
            'bndbox': {
                coord: text_of(box, coord)
                for coord in ('xmin', 'ymin', 'xmax', 'ymax')
            },
        }

    return {
        'filename': text_of(root, 'filename'),
        'folder': text_of(root, 'folder'),
        'size': {
            'width': text_of(root.find('size'), 'width'),
            'height': text_of(root.find('size'), 'height'),
            'depth': text_of(root.find('size'), 'depth'),
        },
        'object': [parse_object(obj) for obj in root.findall('object')],
    }

def load_voc_annotations(anno_root):
    '''
    load every entry of anno_root as a VOC annotation file
    param: anno_root: directory holding the annotation files. Every entry is
        parsed, so it is assumed to contain only annotation XML files.
    return: (anno_file, anno_list): the entry names as returned by
        os.listdir and the loaded annotation dicts, paired by position
    '''
    anno_file = os.listdir(anno_root)
    anno_list = [__load_voc_annotation(osp.join(anno_root, name)) for name in anno_file]
    return anno_file, anno_list

def __save_anno_dict(anno_dict, save_path):
    '''
    write one annotation dict (the layout produced by __load_voc_annotation)
    to save_path as a UTF-8 VOC XML file with an XML declaration
    param: anno_dict: dict with 'filename', 'folder', 'size' and 'object' keys
    param: save_path: destination file path
    '''
    header = voc_basic_xml_element_generator(
        anno_dict['filename'],
        anno_dict['folder'],
        anno_dict['size']['width'],
        anno_dict['size']['height'],
        anno_dict['size']['depth'],
    )
    objects = [voc_object_xml_element_generator(obj) for obj in anno_dict['object']]

    annotation = ET.Element('annotation')
    annotation.extend(list(header))
    annotation.extend(objects)
    ET.ElementTree(annotation).write(save_path, encoding='utf-8', xml_declaration=True)

def save_anno_list(anno_files, anno_list, save_root, fork_time=0):
    '''
    description: write every annotation dict of anno_list to save_root as an
        XML file, optionally splitting the work across forked processes
    param:
        anno_files: output file names (relative to save_root), paired by
            position with anno_list; must not be shorter than anno_list
        anno_list: annotation dicts in the layout produced by
            load_voc_annotations
        save_root: output directory, created (with parents) if missing
        fork_time: number of successive fork rounds. Each round halves every
            process's share, so up to 2 ** fork_time processes write in
            parallel. The default 0 writes everything in the calling process.
            Values > 0 require os.fork (POSIX only).
    return: None. When forking is used, only the calling process returns, and
        only after all of its forked children have finished writing.
    raises:
        ValueError: if there are fewer file names than annotations
    '''
    import sys
    import traceback

    if len(anno_files) < len(anno_list):
        raise ValueError('got {} file names for {} annotations'.format(
            len(anno_files), len(anno_list)))

    os.makedirs(save_root, exist_ok=True)

    # Split (file name, annotation) pairs together so that every process
    # writes its share of annotations under their own file names.
    pairs = list(zip(anno_files, anno_list))
    is_child = False
    children = []

    for _ in range(fork_time):
        half = len(pairs) // 2
        try:
            pid = os.fork()
        except OSError:
            # os.fork() raises on failure (it never returns a negative pid):
            # stop splitting and let this process write its remaining share.
            print('fork failed！')
            break
        if pid == 0:
            # child: takes the first half; the parent's children are not ours
            is_child = True
            children = []
            pairs = pairs[:half]
        else:
            # parent: keeps the second half and remembers the child to reap
            children.append(pid)
            pairs = pairs[half:]

    exit_code = 0
    try:
        for filename, anno_dict in pairs:
            __save_anno_dict(anno_dict, osp.join(save_root, filename))
    except BaseException:
        if not is_child:
            raise
        # A forked child must never propagate into the caller's code; report
        # the error and signal it through the exit status instead.
        traceback.print_exc()
        exit_code = 1
    finally:
        # Reap the children this process forked so none are left as zombies.
        for pid in children:
            os.waitpid(pid, 0)

    if is_child:
        # Forked children must not return, or each would re-run the caller's
        # remaining code. os._exit skips Python cleanup, so flush first.
        sys.stdout.flush()
        sys.stderr.flush()
        os._exit(exit_code)